
<!DOCTYPE html>

<html lang="zh">
  <head>
    <meta charset="utf-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.17.1: http://docutils.sourceforge.net/" />

    <title>Article Structure &#8212; 深入浅出PyTorch</title>
    
  <!-- Loaded before other Sphinx assets -->
  <link href="../_static/styles/theme.css?digest=1999514e3f237ded88cf" rel="stylesheet">
<link href="../_static/styles/pydata-sphinx-theme.css?digest=1999514e3f237ded88cf" rel="stylesheet">

    
  <link rel="stylesheet"
    href="../_static/vendor/fontawesome/5.13.0/css/all.min.css">
  <link rel="preload" as="font" type="font/woff2" crossorigin
    href="../_static/vendor/fontawesome/5.13.0/webfonts/fa-solid-900.woff2">
  <link rel="preload" as="font" type="font/woff2" crossorigin
    href="../_static/vendor/fontawesome/5.13.0/webfonts/fa-brands-400.woff2">

    <link rel="stylesheet" type="text/css" href="../_static/pygments.css" />
    <link rel="stylesheet" href="../_static/styles/sphinx-book-theme.css?digest=62ba249389abaaa9ffc34bf36a076bdc1d65ee18" type="text/css" />
    <link rel="stylesheet" type="text/css" href="../_static/togglebutton.css" />
    <link rel="stylesheet" type="text/css" href="../_static/mystnb.css" />
    <link rel="stylesheet" type="text/css" href="../_static/plot_directive.css" />
    
  <!-- Pre-loaded scripts that we'll load fully later -->
  <link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=1999514e3f237ded88cf">

    <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
    <script src="../_static/jquery.js"></script>
    <script src="../_static/underscore.js"></script>
    <script src="../_static/doctools.js"></script>
    <script>let toggleHintShow = 'Click to show';</script>
    <script>let toggleHintHide = 'Click to hide';</script>
    <script>let toggleOpenOnPrint = 'true';</script>
    <script src="../_static/togglebutton.js"></script>
    <script src="../_static/scripts/sphinx-book-theme.js?digest=f31d14ad54b65d19161ba51d4ffff3a77ae00456"></script>
    <script>var togglebuttonSelector = '.toggle, .admonition.dropdown, .tag_hide_input div.cell_input, .tag_hide-input div.cell_input, .tag_hide_output div.cell_output, .tag_hide-output div.cell_output, .tag_hide_cell.cell, .tag_hide-cell.cell';</script>
    <script>window.MathJax = {"options": {"processHtmlClass": "tex2jax_process|mathjax_process|math|output_area"}}</script>
    <script defer="defer" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
    <link rel="index" title="Index" href="../genindex.html" />
    <link rel="search" title="Search" href="../search.html" />
    <link rel="next" title="Transformer Explained" href="Transformer%20%E8%A7%A3%E8%AF%BB.html" />
    <link rel="prev" title="Article Structure" href="RNN%E8%AF%A6%E8%A7%A3%E5%8F%8A%E5%85%B6%E5%AE%9E%E7%8E%B0.html" />
    <meta name="viewport" content="width=device-width, initial-scale=1" />
    <meta name="docsearch:language" content="zh">
    

    <!-- Google Analytics -->
    
  </head>
  <body data-spy="scroll" data-target="#bd-toc-nav" data-offset="60">
<!-- Checkboxes to toggle the left sidebar -->
<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation" aria-label="Toggle navigation sidebar">
<label class="overlay overlay-navbar" for="__navigation">
    <div class="visually-hidden">Toggle navigation sidebar</div>
</label>
<!-- Checkboxes to toggle the in-page toc -->
<input type="checkbox" class="sidebar-toggle" name="__page-toc" id="__page-toc" aria-label="Toggle in-page Table of Contents">
<label class="overlay overlay-pagetoc" for="__page-toc">
    <div class="visually-hidden">Toggle in-page Table of Contents</div>
</label>
<!-- Headers at the top -->
<div class="announcement header-item noprint"></div>
<div class="header header-item noprint"></div>

    
    <div class="container-fluid" id="banner"></div>

    

    <div class="container-xl">
      <div class="row">
          
<!-- Sidebar -->
<div class="bd-sidebar noprint" id="site-navigation">
    <div class="bd-sidebar__content">
        <div class="bd-sidebar__top"><div class="navbar-brand-box">
    <a class="navbar-brand text-wrap" href="../index.html">
      
      
      
      <h1 class="site-logo" id="site-title">深入浅出PyTorch</h1>
      
    </a>
</div><form class="bd-search d-flex align-items-center" action="../search.html" method="get">
  <i class="icon fas fa-search"></i>
  <input type="search" class="form-control" name="q" id="search-input" placeholder="Search the docs ..." aria-label="Search the docs ..." autocomplete="off" >
</form><nav class="bd-links" id="bd-docs-nav" aria-label="Main">
    <div class="bd-toc-item active">

    </div>
</nav></div>
        <div class="bd-sidebar__bottom">
             <!-- To handle the deprecated key -->
            
            <div class="navbar_extra_footer">
            Theme by the <a href="https://ebp.jupyterbook.org">Executable Book Project</a>
            </div>
            
        </div>
    </div>
    <div id="rtd-footer-container"></div>
</div>


          


          
<!-- A tiny helper pixel to detect if we've scrolled -->
<div class="sbt-scroll-pixel-helper"></div>
<!-- Main content -->
<div class="col py-0 content-container">
    
    <div class="header-article row sticky-top noprint">
        



<div class="col py-1 d-flex header-article-main">
    <div class="header-article__left">
        
        <label for="__navigation"
  class="headerbtn"
  data-toggle="tooltip"
data-placement="right"
title="Toggle navigation"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-bars"></i>
  </span>

</label>

        
    </div>
    <div class="header-article__right">
<label for="__page-toc"
  class="headerbtn headerbtn-page-toc"
  
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-list"></i>
  </span>

</label>

    </div>
</div>

<!-- Table of contents -->
<div class="col-md-3 bd-toc show noprint">
    <div class="tocsection onthispage pt-5 pb-3">
        <i class="fas fa-list"></i> Contents
    </div>
    <nav id="bd-toc-nav" aria-label="Page">
        <ul class="visible nav section-nav flex-column">
 <li class="toc-h1 nav-item toc-entry">
  <a class="reference internal nav-link" href="#">
   Article Structure
  </a>
 </li>
 <li class="toc-h1 nav-item toc-entry">
  <a class="reference internal nav-link" href="#lstm">
   Understanding LSTM
  </a>
  <ul class="visible nav section-nav flex-column">
   <li class="toc-h2 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id2">
     Gates
    </a>
   </li>
   <li class="toc-h2 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id3">
     Forget Gate
    </a>
   </li>
   <li class="toc-h2 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id4">
     Memory Gate
    </a>
   </li>
   <li class="toc-h2 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id5">
     State Update
    </a>
   </li>
   <li class="toc-h2 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id6">
     Output Gate
    </a>
   </li>
   <li class="toc-h2 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id7">
     Model Summary
    </a>
   </li>
  </ul>
 </li>
 <li class="toc-h1 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id8">
   LSTM in Practice
  </a>
  <ul class="visible nav section-nav flex-column">
   <li class="toc-h2 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id9">
     Experiment Description
    </a>
   </li>
   <li class="toc-h2 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id10">
     Model Implementation
    </a>
   </li>
   <li class="toc-h2 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id11">
     Hyperparameters and Parameters
    </a>
    <ul class="nav section-nav flex-column">
     <li class="toc-h3 nav-item toc-entry">
      <a class="reference internal nav-link" href="#mylstm-nn-lstm">
       MyLSTM and nn.LSTM
      </a>
     </li>
     <li class="toc-h3 nav-item toc-entry">
      <a class="reference internal nav-link" href="#nn-rnn">
       nn.RNN
      </a>
     </li>
    </ul>
   </li>
   <li class="toc-h2 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id12">
     Experimental Results
    </a>
   </li>
  </ul>
 </li>
 <li class="toc-h1 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id13">
   On the Gradient Problem
  </a>
 </li>
</ul>

    </nav>
</div>
    </div>
    <div class="article row">
        <div class="col pl-md-3 pl-lg-5 content-container">
            <main id="main-content" role="main">
                
              <div>
                
  <section class="tex2jax_ignore mathjax_ignore" id="id1">
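<h1>Article Structure<a class="headerlink" href="#id1" title="Permalink to this heading">#</a></h1>
<p>In <a class="reference internal" href="RNN%E8%AF%A6%E8%A7%A3%E5%8F%8A%E5%85%B6%E5%AE%9E%E7%8E%B0.html"><span class="doc std std-doc">RNN in Detail and in Practice</span></a>, we briefly discussed why models like RNN are needed, the idea behind RNN, and a simple RNN implementation. At the end of that article we also mentioned the vanishing gradient problem of RNN and one later solution to it: <strong>LSTM</strong>. This article is therefore organized as follows:</p>
<ol class="simple">
<li><p>Understanding LSTM, with a simple implementation</p></li>
<li><p>LSTM in practice</p></li>
<li><p>Comparing the classic RNN with LSTM</p></li>
<li><p>On vanishing gradients</p></li>
</ol>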
</section>
<section class="tex2jax_ignore mathjax_ignore" id="lstm">
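<h1>Understanding LSTM<a class="headerlink" href="#lstm" title="Permalink to this heading">#</a></h1>
<p>Strictly speaking, it is misleading to treat LSTM and RNN as two separate things: LSTM is still a kind of RNN. Where the classic RNN handles each step of a sequence with what is essentially a linear-regression-style transform, LSTM designs a dedicated module to replace that transform.</p>
<p>LSTM stands for Long Short-Term Memory. Its core idea is simple: since an RNN can only hold short-term memory, why not add a long-term memory as well? LSTM therefore maintains both a long-term and a short-term memory, and by adjusting the balance between them it keeps the long-term memory reliable and mitigates the vanishing gradient problem of RNN. As the architecture diagram below shows, the number of model inputs grows from two to three: the current input <span class="math notranslate nohighlight">\(\mathbf{X}_{t}\)</span>, the long-term memory <span class="math notranslate nohighlight">\(\mathbf{C}_{t-1}\)</span>, and the short-term memory <span class="math notranslate nohighlight">\(\mathbf{H}_{t-1}\)</span>. There are still two outputs: the current cell state <span class="math notranslate nohighlight">\(\mathbf{C}_{t}\)</span> and the current hidden state <span class="math notranslate nohighlight">\(\mathbf{H}_{t}\)</span>.</p>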
<p><img alt="LSTM architecture diagram" src="../_images/LSTM-arch.jpg" /></p>
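<p>The three-inputs, two-outputs interface described above can be sketched with PyTorch's built-in <code class="docutils literal notranslate"><span class="pre">nn.LSTMCell</span></code>, which performs a single LSTM step. This is only a minimal illustration: the sizes and the variable names <code class="docutils literal notranslate"><span class="pre">x_t</span></code>, <code class="docutils literal notranslate"><span class="pre">h_prev</span></code> and <code class="docutils literal notranslate"><span class="pre">c_prev</span></code> are assumptions chosen for the demo, not values taken from this article.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical sizes, chosen only for this demo
batch_size, input_size, hidden_size = 2, 4, 3

# nn.LSTMCell computes one LSTM time step
cell = nn.LSTMCell(input_size, hidden_size)

x_t = torch.randn(batch_size, input_size)      # current input X_t
h_prev = torch.zeros(batch_size, hidden_size)  # short-term memory H_{t-1}
c_prev = torch.zeros(batch_size, hidden_size)  # long-term memory C_{t-1}

# Three tensors go in, two states come out
h_t, c_t = cell(x_t, (h_prev, c_prev))
print(h_t.shape, c_t.shape)
</pre></div>
</div>
<p>Both returned states have shape <code class="docutils literal notranslate"><span class="pre">(batch_size,</span> <span class="pre">hidden_size)</span></code>, so they can be fed straight back in as the memories for the next step.</p>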
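<p>So how does LSTM control long-term and short-term memory?
This is where the three well-known gates come in:</p>
<ul class="simple">
<li><p>Forget gate: controls how much of the previous step's cell state is kept in the current cell state</p></li>
<li><p>Memory gate (often called the input gate): controls how much of the current step's information is written into the cell state</p></li>
<li><p>Output gate: controls how much information is passed on to the current output</p></li>
</ul>
<p>Before analyzing the three gates, let us first look at the concept of a <strong>gate</strong>.</p>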
<section id="id2">
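<h2>Gates<a class="headerlink" href="#id2" title="Permalink to this heading">#</a></h2>
<p>As the simplified diagram suggests, a <strong>gate</strong> works like a switch in a circuit: when the switch is pressed the information passes, and when it is released the information is blocked. The real thing is similar: a <strong>gate</strong> is a fully connected layer whose input is a vector and whose output is a value in [0, 1].
Let us design a very simple forget gate: after every learning step, a certain amount of what has been learned is forgotten. Note that this forget gate has nothing to do with the forget gate in LSTM; it is only meant to illustrate the concept of a <strong>gate</strong>.</p>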
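<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import torch
import torch.nn as nn

hidden_size = 5
# h_t holds the state from the previous step (random values for the demo)
h_t = torch.randn(1, hidden_size)
# A linear layer that computes how much to forget
gate_linear = nn.Linear(hidden_size, 1)
# A linear layer that does the learning
study_linear = nn.Linear(hidden_size, hidden_size)
# Squash the gate value into the range 0 to 1 with a sigmoid
gate = torch.sigmoid(gate_linear(h_t))
# h_t goes through study_linear to learn
_h_t = study_linear(h_t)
# Before the result is output, the gate scales it down,
# so part of what was learned is forgotten
h_t = gate * _h_t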
</pre></div>
</div>
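<p>As you can see, if the value of <code class="docutils literal notranslate"><span class="pre">gate</span></code> is 0, all history is forgotten, and if it is 1, the history is kept in full. Because the weights of the <code class="docutils literal notranslate"><span class="pre">gate_linear</span></code> network keep being learned during training, we end up with a learnable switch.</p>
<p>However, <code class="docutils literal notranslate"><span class="pre">gate</span></code> is a single floating-point value, so its control over the temporary result matrix <code class="docutils literal notranslate"><span class="pre">_h_t</span></code> is global: when <code class="docutils literal notranslate"><span class="pre">gate</span></code> is 0, the final result <code class="docutils literal notranslate"><span class="pre">h_t</span></code> is an all-zero matrix. Note, therefore, that LSTM does not use one big sluice gate like this; it uses small taps that control each variable separately (built with the activation function <code class="docutils literal notranslate"><span class="pre">nn.Sigmoid</span></code>).</p>
<p>In LSTM, a gate is implemented mainly as a sigmoid neural-network layer (<strong>note again: not the bare activation function, but a sigmoid layer, that is, a learnable linear transform followed by the sigmoid activation</strong>).</p>
<p>Here is some example code:</p>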
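<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import torch
import torch.nn as nn

hidden_size = 5
sigmoid = nn.Sigmoid()
# Hidden state; all ones to keep the arithmetic simple
hidden_emb = torch.ones(hidden_size, hidden_size)
# Some intermediate layer of the network
model = nn.Linear(hidden_size, hidden_size)
# Output of that layer, not yet restricted by the gate
mid_out = model(hidden_emb)
# Build a gate -- note: it need not be driven by this variable
# For example, it could be driven by the previous hidden state instead
# The code would be: gate = sigmoid(hidden_emb)
gate = sigmoid(mid_out)
# Final output
final_out = gate * mid_out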
</pre></div>
</div>
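<p>The difference between the single "sluice gate" and the per-element "taps" described above can also be seen numerically. The following is a small sketch with made-up values (the tensors <code class="docutils literal notranslate"><span class="pre">values</span></code>, <code class="docutils literal notranslate"><span class="pre">scalar_gate</span></code> and <code class="docutils literal notranslate"><span class="pre">element_gate</span></code> are assumptions for illustration only):</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span>import torch

values = torch.tensor([1.0, 2.0, 3.0])

# One scalar gate: every element is scaled by the same factor,
# so a gate of 0 wipes out the whole tensor
scalar_gate = torch.tensor(0.0)
closed = scalar_gate * values

# Per-element gate: each element gets its own value between 0 and 1
# (sigmoid of a very negative, a zero, and a very positive logit)
element_gate = torch.sigmoid(torch.tensor([-10.0, 0.0, 10.0]))
mixed = element_gate * values

print(closed)
print(mixed)
</pre></div>
</div>
<p>With the scalar gate the result is all zeros, while the per-element gate almost erases the first entry, halves the second, and leaves the third nearly untouched.</p>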
<p>With the basics of gates in place, we now analyze the forget gate, the memory gate, and the output gate one by one.</p>
</section>
<section id="id3">
<h2>Forget Gate<a class="headerlink" href="#id3" title="Permalink to this heading">#</a></h2>
<p>The part of the cell involved in the forget gate is shown below:
<img alt="LSTM forget gate" src="../_images/LSTM-gate_f.jpg" /></p>
<p>其中，下方蓝色表示三个门共用的输入部分，均为 [<span class="math notranslate nohighlight">\(\mathbf{h}_{t-1}\)</span>,<span class="math notranslate nohighlight">\(\mathbf{X}_{t}\)</span>],需要注意，这里由于三个门之间并不共享权重参数，因此公示虽然接近，但是一共计算了三次，遗忘门被标记为 <span class="math notranslate nohighlight">\(f_t\)</span>, 列出遗忘门公式为：
$<span class="math notranslate nohighlight">\(
f_t = \sigma(\mathbf{W_f} * [\mathbf{h}_{t-1},\mathbf{X}_{t}]  + \mathbf{b_f})
\)</span><span class="math notranslate nohighlight">\(
输出结果为取值范围为 [ 0, 1 ] 的矩阵，主要功能是控制与之相乘的矩阵的**遗忘程度**。
将 \)</span>f_t<span class="math notranslate nohighlight">\( 与输入的上一长期状态 \)</span>C_{t-1}<span class="math notranslate nohighlight">\( 相乘：
\)</span><span class="math notranslate nohighlight">\(
C_t' = f_t * C_{t-1}
\)</span>$</p>
<p>一部分的 <span class="math notranslate nohighlight">\(C_{t-1}\)</span> 就这样被遗忘了。</p>
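<p>The two forget-gate formulas can be sketched in a few lines of PyTorch. This is only an illustration under assumed sizes; the names <code class="docutils literal notranslate"><span class="pre">fc_f</span></code>, <code class="docutils literal notranslate"><span class="pre">h_prev</span></code>, <code class="docutils literal notranslate"><span class="pre">c_prev</span></code> and <code class="docutils literal notranslate"><span class="pre">c_forgot</span></code> are chosen here and the inputs are random.</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, input_size, hidden_size = 3, 4, 6  # assumed sizes for the sketch

x_t = torch.randn(batch, input_size)      # X_t
h_prev = torch.randn(batch, hidden_size)  # h_{t-1}
c_prev = torch.randn(batch, hidden_size)  # C_{t-1}

# W_f and b_f live inside one linear layer acting on the concatenation [h_{t-1}, X_t]
fc_f = nn.Linear(hidden_size + input_size, hidden_size)

# f_t = sigma(W_f * [h_{t-1}, X_t] + b_f), one value in [0, 1] per state element
f_t = torch.sigmoid(fc_f(torch.cat([h_prev, x_t], dim=1)))

# C_t' = f_t * C_{t-1}: element-wise scaling of the long-term state
c_forgot = f_t * c_prev

print(f_t.shape, c_forgot.shape)
```

<p>Because every entry of <code class="docutils literal notranslate"><span class="pre">f_t</span></code> lies in [0, 1], no entry of <code class="docutils literal notranslate"><span class="pre">c_forgot</span></code> can be larger in magnitude than the corresponding entry of <code class="docutils literal notranslate"><span class="pre">c_prev</span></code>.</p>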
</section>
<section id="id4">
<h2>Memory gate<a class="headerlink" href="#id4" title="Permalink to this heading">#</a></h2>
<p>The part of the cell involved in the memory gate (usually called the input gate in the literature) is shown below:
<img alt="LSTM memory gate" src="../_images/LSTM-gate_m.jpg" /></p>
<p>As the figure shows, both factors multiplied in the memory gate are computed from <span class="math notranslate nohighlight">\(\mathbf{h}_{t-1}\)</span> and <span class="math notranslate nohighlight">\(\mathbf{X}_{t}\)</span>.
The left factor, which controls how much is memorized, has essentially the same formula as the forget gate:</p>
<div class="math notranslate nohighlight">
\[
i_t = \sigma(\mathbf{W_i} * [\mathbf{h}_{t-1},\mathbf{X}_{t}]  + \mathbf{b_i})
\]</div>
<p>As with the forget gate, the output is a matrix with values in [0, 1]; its main job is to control the <strong>degree of memorization</strong> of the matrix it is multiplied with.
The right factor swaps the activation function from <span class="math notranslate nohighlight">\(sigmoid\)</span> to <span class="math notranslate nohighlight">\(tanh\)</span>:</p>
<div class="math notranslate nohighlight">
\[
\tilde{C_t} = \tanh(\mathbf{W_c} * [\mathbf{h}_{t-1},\mathbf{X}_{t}]  + \mathbf{b_c})
\]</div>
<p>This formula can be seen as responsible for updating the <strong>short-term hidden state</strong>; its values lie in [-1, 1].</p>
<p>The final memory gate update is:</p>
<div class="math notranslate nohighlight">
\[
\tilde{C_t'}=  i_t * \tilde{C_t}
\]</div>
<p>We can say that <span class="math notranslate nohighlight">\(\tilde{C_t'}\)</span> is a short-term state with part of its content retained.</p>
</section>
<section id="id5">
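<h2>State update<a class="headerlink" href="#id5" title="Permalink to this heading">#</a></h2>
<p><img alt="LSTM state update" src="../_images/LSTM-update.jpg" /></p>
<p>Once the forget gate has produced the long-term state <span class="math notranslate nohighlight">\(C_t'\)</span> with part of its content forgotten, and the memory gate has produced the short-term state <span class="math notranslate nohighlight">\(\tilde{C_t'}\)</span> with part of its content retained, the two can be combined directly by addition:</p>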
<div class="math notranslate nohighlight">
\[
C_t =  C_t' + \tilde{C_t'}
\]</div>
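<p>A small worked check of this additive update: treating the gate values as constants (stand-in random tensors here, an assumption made purely for illustration), the gradient of <span class="math notranslate nohighlight">\(C_t\)</span> with respect to <span class="math notranslate nohighlight">\(C_{t-1}\)</span> is exactly <span class="math notranslate nohighlight">\(f_t\)</span>. The sketch below verifies this with autograd.</p>

```python
import torch

torch.manual_seed(0)
hidden_size = 4

c_prev = torch.randn(1, hidden_size, requires_grad=True)  # C_{t-1}
# Stand-in gate values, not produced by a trained network
f_t = torch.rand(1, hidden_size)              # forget gate, values in [0, 1)
i_t = torch.rand(1, hidden_size)              # memory gate, values in [0, 1)
c_tilde = torch.rand(1, hidden_size) * 2 - 1  # candidate state, values in [-1, 1)

c_forgot = f_t * c_prev     # C_t'
c_written = i_t * c_tilde   # the gated candidate, tilde{C_t'}
c_t = c_forgot + c_written  # C_t = C_t' + tilde{C_t'}

# Back-propagate through the addition and compare the gradient with f_t
c_t.sum().backward()
grad_matches_gate = torch.allclose(c_prev.grad, f_t)
print(grad_matches_gate)
```

<p>The addition passes the gradient through without squashing it by a weight matrix or an activation derivative; only the forget gate scales it.</p>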
</section>
<section id="id6">
<h2>Output gate<a class="headerlink" href="#id6" title="Permalink to this heading">#</a></h2>
<p><img alt="LSTM output gate" src="../_images/LSTM-gate_o.jpg" /></p>
<p>The output gate is the last of the three gates. When the data arrives here, we control how much of the long-term state <span class="math notranslate nohighlight">\(C_t\)</span> is written into <span class="math notranslate nohighlight">\(h_t\)</span>. The gate formula has the same form as before, so we will not repeat the explanation:</p>
<div class="math notranslate nohighlight">
\[
o_t = \sigma(\mathbf{W_o} * [\mathbf{h}_{t-1},\mathbf{X}_{t}]  + \mathbf{b_o})
\]</div>
<div class="math notranslate nohighlight">
\[
h_t = o_t * \tanh(C_t)
\]</div>
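<p>Putting the formulas of the three gates together gives one complete LSTM step. As a sanity check, the sketch below evaluates them by hand using the weights of a <code class="docutils literal notranslate"><span class="pre">torch.nn.LSTMCell</span></code> and compares the result with the cell's own output. The sizes are assumptions for the example. PyTorch stores the four blocks in the order input, forget, candidate, output, and keeps two weight matrices and two bias vectors, which are merged here into the single <span class="math notranslate nohighlight">\(\mathbf{W}\)</span> and <span class="math notranslate nohighlight">\(\mathbf{b}\)</span> used in the text.</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, input_size, hidden_size = 2, 3, 4  # assumed sizes for the sketch
cell = nn.LSTMCell(input_size, hidden_size)

x_t = torch.randn(batch, input_size)
h_prev = torch.randn(batch, hidden_size)
c_prev = torch.randn(batch, hidden_size)

# Merge PyTorch's parameters into W * [h_{t-1}, X_t] + b
W = torch.cat([cell.weight_hh, cell.weight_ih], dim=1)  # (4*hidden, hidden+input)
b = cell.bias_ih + cell.bias_hh
z = torch.cat([h_prev, x_t], dim=1) @ W.t() + b

# Split into the four blocks: input (memory) gate, forget gate, candidate, output gate
i_lin, f_lin, g_lin, o_lin = z.chunk(4, dim=1)
i_t = torch.sigmoid(i_lin)
f_t = torch.sigmoid(f_lin)
g_t = torch.tanh(g_lin)  # the candidate state, tilde{C_t}
o_t = torch.sigmoid(o_lin)

c_t = f_t * c_prev + i_t * g_t  # state update
h_t = o_t * torch.tanh(c_t)     # output gate

# Reference result from PyTorch's own implementation
h_ref, c_ref = cell(x_t, (h_prev, c_prev))
print(torch.allclose(h_t, h_ref, atol=1e-6), torch.allclose(c_t, c_ref, atol=1e-6))
```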
</section>
<section id="id7">
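<h2>Model summary<a class="headerlink" href="#id7" title="Permalink to this heading">#</a></h2>
<p>As we have seen, the core of every formula is strikingly similar:</p>
<div class="math notranslate nohighlight">
\[
\mathbf{W_c} * [\mathbf{h}_{t-1},\mathbf{X}_{t}]  + \mathbf{b_c}
\]</div>
<p>This part is just a simple linear function. So what makes LSTM more advanced than a plain RNN is not any single formula, but the way it reroutes the flow of data and blends it in certain proportions, which mitigates the vanishing-gradient problem over long distances.</p>
<p>Overall, LSTM is essentially an upgraded RNN: it initializes an additional state <span class="math notranslate nohighlight">\(C\)</span> to hold long-term memory and to regulate how much weight distant information carries. The output is handled in much the same way.</p>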
</section>
</section>
<section class="tex2jax_ignore mathjax_ignore" id="id8">
<h1>LSTM in practice<a class="headerlink" href="#id8" title="Permalink to this heading">#</a></h1>
<section id="id9">
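<h2>Experiment description<a class="headerlink" href="#id9" title="Permalink to this heading">#</a></h2>
<p>The complete code can be downloaded <span class="xref myst">here</span>. It uses three models and compares their results: an LSTM model that I implemented entirely with <code class="docutils literal notranslate"><span class="pre">nn.Linear</span></code>, an LSTM model built on <code class="docutils literal notranslate"><span class="pre">nn.LSTM</span></code>, and an RNN model built on <code class="docutils literal notranslate"><span class="pre">nn.RNN</span></code>.</p>
<p>The experiments use the <a class="reference external" href="http://ai.stanford.edu/~amaas/data/sentiment/">IMDB dataset</a>. It consists mainly of movie reviews of varying length, <strong>but samples with a length of around 1000 are common</strong>. The dataset is balanced: it contains 50000 samples in total, 25000 for training and 25000 for testing, and the ratio of positive to negative samples is 1:1 in both splits.</p>
<p>From what we know about RNNs, it is very hard for them to learn anything useful at such lengths, which makes this dataset well suited to comparing RNN and LSTM.</p>
<p>To make the code easy to reproduce, the implementation relies on <code class="docutils literal notranslate"><span class="pre">torchtext</span></code> to download and load the data.</p>
<p>To show that each model really does learn something, some hyperparameters differ between the models in the comparison; you can set them to the same values locally for a more careful comparison.</p>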
</section>
<section id="id10">
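<h2>Model implementation<a class="headerlink" href="#id10" title="Permalink to this heading">#</a></h2>
<p>Here we walk through the LSTM model that I implemented, and use it to understand LSTM. (P.S.: my abilities are limited, so I did not implement the <code class="docutils literal notranslate"><span class="pre">num_layers</span></code> and <code class="docutils literal notranslate"><span class="pre">Bi-LSTM</span></code> features; the implementation may also have other problems, and feedback is welcome.)</p>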
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>

<span class="c1"># Define the base model</span>
<span class="k">class</span> <span class="nc">LSTM</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">):</span>
        <span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        args:</span>
<span class="sd">            input_size: size of the input features</span>
<span class="sd">            hidden_size: size of the hidden layer</span>
<span class="sd">            num_classes: number of output classes; in this example the label is 0 or 1</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">LSTM</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">input_size</span> <span class="o">=</span> <span class="n">input_size</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">=</span> <span class="n">hidden_size</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">fc_i</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">input_size</span> <span class="o">+</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">fc_f</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">input_size</span> <span class="o">+</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">fc_g</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">input_size</span> <span class="o">+</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">fc_o</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">input_size</span> <span class="o">+</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">sigmoid</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sigmoid</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">tanh</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Tanh</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">fc_out</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span>
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="c1"># x has shape (batch, seq_len, input_size)</span>
        <span class="c1"># Initialize the hidden state -- short-term memory</span>
        <span class="c1"># Note: the state is all zeros at every position, so all time steps are</span>
        <span class="c1"># processed at once and nothing is carried from one step to the next</span>
        <span class="n">h_t</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
        <span class="c1"># Initialize the cell state -- long-term memory</span>
        <span class="n">c_t</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
        <span class="c1"># Concatenate the input with the short-term memory</span>
        <span class="n">combined</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">x</span><span class="p">,</span> <span class="n">h_t</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
        <span class="c1"># Memory gate -- output values lie between 0 and 1</span>
        <span class="n">i_t</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fc_i</span><span class="p">(</span><span class="n">combined</span><span class="p">))</span>
        <span class="c1"># Forget gate -- output values lie between 0 and 1</span>
        <span class="n">f_t</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fc_f</span><span class="p">(</span><span class="n">combined</span><span class="p">))</span>
        <span class="c1"># Candidate memory -- output values lie between -1 and 1</span>
        <span class="n">g_t</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fc_g</span><span class="p">(</span><span class="n">combined</span><span class="p">))</span>
        <span class="c1"># Output gate -- output values lie between 0 and 1</span>
        <span class="n">o_t</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fc_o</span><span class="p">(</span><span class="n">combined</span><span class="p">))</span>
        <span class="c1"># Long-term state = forget gate * previous long-term state + memory gate * candidate memory</span>
        <span class="n">c_t</span> <span class="o">=</span> <span class="n">f_t</span> <span class="o">*</span> <span class="n">c_t</span> <span class="o">+</span> <span class="n">i_t</span> <span class="o">*</span> <span class="n">g_t</span>
        <span class="c1"># Hidden state = output gate * tanh(long-term state)</span>
        <span class="n">h_t</span> <span class="o">=</span> <span class="n">o_t</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">c_t</span><span class="p">)</span>
        <span class="c1"># Dimension reduction: average over the sequence dimension</span>
        <span class="n">h_t</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">avg_pool2d</span><span class="p">(</span><span class="n">h_t</span><span class="p">,</span> <span class="p">(</span><span class="n">h_t</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span><span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
        <span class="c1"># Classification layer</span>
        <span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc_out</span><span class="p">(</span><span class="n">h_t</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">out</span>
</pre></div>
</div>
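<p>The <code class="docutils literal notranslate"><span class="pre">forward</span></code> above applies the gate formulas once to all positions, starting from zero states. For comparison, here is a minimal sketch of a step-by-step variant in which <span class="math notranslate nohighlight">\(h_t\)</span> and <span class="math notranslate nohighlight">\(C_t\)</span> are carried from one time step to the next, as in the formulas of the previous sections. The class name <code class="docutils literal notranslate"><span class="pre">RecurrentLSTM</span></code>, the fused <code class="docutils literal notranslate"><span class="pre">fc_gates</span></code> layer, and the choice to classify from the last hidden state are assumptions of this sketch, not part of the original code.</p>

```python
import torch
import torch.nn as nn


class RecurrentLSTM(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        # One fused layer for i, f, g, o; same parameter count as four separate layers
        self.fc_gates = nn.Linear(input_size + hidden_size, 4 * hidden_size)
        self.fc_out = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # x has shape (batch, seq_len, input_size)
        batch, seq_len, _ = x.shape
        h_t = x.new_zeros(batch, self.hidden_size)  # short-term memory
        c_t = x.new_zeros(batch, self.hidden_size)  # long-term memory
        for t in range(seq_len):
            # Concatenate the current input with the previous hidden state
            combined = torch.cat((x[:, t, :], h_t), dim=1)
            i, f, g, o = self.fc_gates(combined).chunk(4, dim=1)
            # State update: forget part of the old state, write part of the candidate
            c_t = torch.sigmoid(f) * c_t + torch.sigmoid(i) * torch.tanh(g)
            # Output gate decides how much of the state becomes the hidden state
            h_t = torch.sigmoid(o) * torch.tanh(c_t)
        # Classify from the hidden state after the last time step
        return self.fc_out(h_t)


model = RecurrentLSTM(input_size=64, hidden_size=128, num_classes=2)
logits = model(torch.randn(4, 10, 64))  # batch of 4 sequences of length 10
n_params = sum(p.numel() for p in model.parameters())
print(logits.shape, n_params)
```

<p>The Python loop makes this variant slower than <code class="docutils literal notranslate"><span class="pre">nn.LSTM</span></code>, and it has not been run on the IMDB experiment described here.</p>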
</section>
<section id="id11">
<h2>Hyperparameters and parameter counts<a class="headerlink" href="#id11" title="Permalink to this heading">#</a></h2>
<section id="mylstm-nn-lstm">
<h3>MyLSTM and nn.LSTM<a class="headerlink" href="#mylstm-nn-lstm" title="Permalink to this heading">#</a></h3>
<table class="colwidths-auto table">
<thead>
<tr class="row-odd"><th class="head"><p>Name</p></th>
<th class="head"><p>Value</p></th>
</tr>
</thead>
<tbody>
<tr class="row-even"><td><p>learning_rate</p></td>
<td><p>0.001</p></td>
</tr>
<tr class="row-odd"><td><p>batch_size</p></td>
<td><p>32</p></td>
</tr>
<tr class="row-even"><td><p>epoch</p></td>
<td><p>6 (MyLSTM) / 3 (nn.LSTM)</p></td>
</tr>
<tr class="row-odd"><td><p>input_size</p></td>
<td><p>64</p></td>
</tr>
<tr class="row-even"><td><p>hidden_size</p></td>
<td><p>128</p></td>
</tr>
<tr class="row-odd"><td><p>num_classes</p></td>
<td><p>2</p></td>
</tr>
</tbody>
</table>
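<p>With these settings:
MyLSTM parameter count: 99074;
nn.LSTM parameter count: 99328</p>
<p>The MyLSTM I implemented has 254 fewer parameters than <code class="docutils literal notranslate"><span class="pre">nn.LSTM</span></code>. In my experiments <code class="docutils literal notranslate"><span class="pre">nn.LSTM</span></code> most likely converges faster than MyLSTM and therefore tends to overfit earlier, so I halved its number of training epochs: MyLSTM is trained for 6 epochs and <code class="docutils literal notranslate"><span class="pre">nn.LSTM</span></code> for 3. The two then reach roughly similar results.</p>
<p>In addition, the code appends an <code class="docutils literal notranslate"><span class="pre">nn.Linear</span></code> after <code class="docutils literal notranslate"><span class="pre">nn.LSTM</span></code> for binary classification. Its 258 parameters (128 × 2 weights plus 2 biases) are not included in the 99328 above, so MyLSTM and LSTM differ by 512 parameters in total. This gap matches the second bias vector of <code class="docutils literal notranslate"><span class="pre">nn.LSTM</span></code>: it keeps separate <code class="docutils literal notranslate"><span class="pre">bias_ih</span></code> and <code class="docutils literal notranslate"><span class="pre">bias_hh</span></code> vectors of 4 × 128 = 512 entries each, whereas each <code class="docutils literal notranslate"><span class="pre">nn.Linear</span></code> gate in MyLSTM has a single bias.</p>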
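<p>The parameter counts quoted above can be reproduced with a short script. This is a sketch that rebuilds only the layer shapes from the hyperparameter table; the helper <code class="docutils literal notranslate"><span class="pre">count</span></code> and the variable names are introduced here for illustration.</p>

```python
import torch.nn as nn

input_size, hidden_size, num_classes = 64, 128, 2


def count(module):
    """Total number of parameters in a module."""
    return sum(p.numel() for p in module.parameters())


# nn.LSTM with a single layer, as in the experiment
lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
# The four gate layers of MyLSTM plus its classification layer
gate_layers = [nn.Linear(input_size + hidden_size, hidden_size) for _ in range(4)]
fc_out = nn.Linear(hidden_size, num_classes)

n_nn_lstm = count(lstm)
n_my_gates = sum(count(layer) for layer in gate_layers)
n_fc_out = count(fc_out)
# nn.LSTM keeps a second bias vector (bias_hh_l0) next to bias_ih_l0
n_extra_bias = lstm.bias_hh_l0.numel()

print("nn.LSTM:", n_nn_lstm)                            # 99328
print("MyLSTM:", n_my_gates + n_fc_out)                 # 99074
print("gap without fc_out:", n_nn_lstm - n_my_gates)    # 512
print("second bias vector:", n_extra_bias)              # 512
```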
</section>
<section id="nn-rnn">
<h3>nn.RNN<a class="headerlink" href="#nn-rnn" title="Permalink to this heading">#</a></h3>
<table class="colwidths-auto table">
<thead>
<tr class="row-odd"><th class="head"><p>Name</p></th>
<th class="head"><p>Value</p></th>
</tr>
</thead>
<tbody>
<tr class="row-even"><td><p>learning_rate</p></td>
<td><p><strong>0.0001</strong></p></td>
</tr>
<tr class="row-odd"><td><p>batch_size</p></td>
<td><p>32</p></td>
</tr>
<tr class="row-even"><td><p>epoch</p></td>
<td><p><strong>12-18</strong></p></td>
</tr>
<tr class="row-odd"><td><p>input_size</p></td>
<td><p>64</p></td>
</tr>
<tr class="row-even"><td><p>hidden_size</p></td>
<td><p>128</p></td>
</tr>
<tr class="row-odd"><td><p>num_classes</p></td>
<td><p>2</p></td>
</tr>
</tbody>
</table>
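<p>With these settings:
nn.RNN parameter count: 25090</p>
<p>Because the samples in this experiment are around 1000 tokens long, the RNN turned out to be highly unstable. Compared with LSTM, it is more prone to exploding gradients, needs more training epochs, and requires a lower learning rate, among other problems; even so, a consistently good result cannot be guaranteed.</p>
<p>As an analogy, imagine a student doing reading comprehension who must state the sentiment of an article based on its content. The student only likes to read the last sentence and always answers from that alone. He is then basically guessing, and can pick up only superficial knowledge.</p>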
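<p>The "only reads the last sentence" effect can be observed directly by measuring how strongly the final output of an untrained <code class="docutils literal notranslate"><span class="pre">nn.RNN</span></code> depends on each input position. The following is a rough sketch with assumed toy sizes and PyTorch's default initialization, not the IMDB setup; the exact numbers depend on the random seed.</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, input_size, hidden_size = 100, 8, 32  # assumed toy sizes
rnn = nn.RNN(input_size, hidden_size, batch_first=True)

# One random sequence; requires_grad lets us read the gradient per input position
x = torch.randn(1, seq_len, input_size, requires_grad=True)
output, _ = rnn(x)  # output has shape (1, seq_len, hidden_size)

# Back-propagate from the output at the last time step only
output[:, -1, :].sum().backward()

# Gradient norm for every input position, shape (seq_len,)
grad_per_step = x.grad.norm(dim=2).squeeze(0)
early = grad_per_step[0].item()
late = grad_per_step[-1].item()
print("gradient norm at first step:", early)
print("gradient norm at last step:", late)
```

<p>In this setting the gradient that reaches the first positions is many orders of magnitude smaller than the one at the last position, so the final output is effectively determined by the most recent inputs.</p>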
</section>
</section>
<section id="id12">
<h2>Experimental results<a class="headerlink" href="#id12" title="Permalink to this heading">#</a></h2>
<table class="colwidths-auto table">
<thead>
<tr class="row-odd"><th class="head"><p>MyLSTM</p></th>
<th class="head"><p>nn.LSTM</p></th>
<th class="head"><p>nn.RNN</p></th>
</tr>
</thead>
<tbody>
<tr class="row-even"><td><p>0.86</p></td>
<td><p>0.80</p></td>
<td><p>0.67</p></td>
</tr>
</tbody>
</table>
</section>
</section>
<section class="tex2jax_ignore mathjax_ignore" id="id13">
<h1>On gradient problems<a class="headerlink" href="#id13" title="Permalink to this heading">#</a></h1>
<ul class="simple">
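<li><p>In an RNN, <strong>the total gradient does not vanish</strong>. Even if the gradient gets weaker the further it travels, it is the long-range gradient that fades, while the short-range gradient does not; the sum of the gradients therefore does not vanish. What "vanishing gradient" really means for an RNN is that the gradient is dominated by its short-range part, which makes it hard for the model to learn long-range dependencies.</p></li>
<li><p>LSTM has several paths along which information flows. Among them, <strong>the gradient flow along the element-wise addition path is the most stable</strong>, while the other paths resemble a basic RNN and still suffer from repeated multiplication.</p></li>
<li><p>When LSTM was first proposed, it had no forget gate. Historical data could then travel along this path without loss, so it can be viewed as a <strong>highway</strong>, similar to the residual connection in ResNet.</p></li>
<li><p>On the other paths, however, LSTM is not much different from an RNN and the gradient can still explode or vanish. Because <strong>the total long-range gradient = the sum of the long-range gradients over all paths</strong>, the total long-range gradient does not vanish as long as it survives on one path. In that sense, LSTM rescues the total long-range gradient through this single path.</p></li>
<li><p>By the same token, since <strong>the total long-range gradient = the sum of the long-range gradients over all paths</strong>, the gradient flow on the highway may be fairly stable, but the other paths still suffer from vanishing and exploding gradients. Hence total long-range gradient = normal gradient + exploding gradient = exploding gradient, so LSTM can still run into exploding gradients. Compared with a classic RNN, though, the paths in an LSTM are much more rugged and pass through several activation functions, so gradient explosion is far less likely. In practice the problem is usually handled with gradient clipping.</p></li>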
</ul>
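<p>The gradient clipping mentioned in the last point is usually a single call placed between the backward pass and the optimizer step. Below is a minimal sketch using <code class="docutils literal notranslate"><span class="pre">torch.nn.utils.clip_grad_norm_</span></code>; the toy model, the artificially inflated loss, and the threshold <code class="docutils literal notranslate"><span class="pre">max_norm</span></code> of 1.0 are assumptions chosen for illustration, not values from the experiment above.</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(4, 50, 8)  # batch of 4 random sequences of length 50
output, _ = model(x)
# Toy loss, scaled up on purpose so that the gradients are large
loss = output.pow(2).mean() * 1e4

optimizer.zero_grad()
loss.backward()

# Rescale all gradients so that their global L2 norm is at most max_norm;
# the call returns the norm measured before clipping
max_norm = 1.0
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
norm_after = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))

optimizer.step()
print("norm before clipping:", norm_before.item())
print("norm after clipping:", norm_after.item())
```

<p>Clipping changes only the magnitude of the update, not its direction, so it limits the damage of an occasional exploding gradient without otherwise altering training.</p>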
</section>


              </div>
              
            </main>
            <footer class="footer-article noprint">
                
    <!-- Previous / next buttons -->
<div class='prev-next-area'>
    <a class='left-prev' id="prev-link" href="RNN%E8%AF%A6%E8%A7%A3%E5%8F%8A%E5%85%B6%E5%AE%9E%E7%8E%B0.html" title="Previous page">
        <i class="fas fa-angle-left"></i>
        <div class="prev-next-info">
            <p class="prev-next-subtitle">Previous</p>
            <p class="prev-next-title">Article structure</p>
        </div>
    </a>
    <a class='right-next' id="next-link" href="Transformer%20%E8%A7%A3%E8%AF%BB.html" title="Next page">
    <div class="prev-next-info">
        <p class="prev-next-subtitle">Next</p>
        <p class="prev-next-title">Transformer explained</p>
    </div>
    <i class="fas fa-angle-right"></i>
    </a>
</div>
            </footer>
        </div>
    </div>
    <div class="footer-content row">
        <footer class="col footer"><p>
  
    By ZhikangNiu<br/>
  
      &copy; Copyright 2022, ZhikangNiu.<br/>
</p>
        </footer>
    </div>
    
</div>


      </div>
    </div>
  
  <!-- Scripts loaded after <body> so the DOM is not blocked -->
  <script src="../_static/scripts/pydata-sphinx-theme.js?digest=1999514e3f237ded88cf"></script>


  </body>
</html>